{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepping YOLOv5Face model for use in Ente\n",
    "\n",
    "[Paper](https://arxiv.org/abs/2105.12931) | [Github](https://github.com/deepcam-cn/yolov5-face)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Setting up PyTorch weights and source code"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Please manually put the PyTorch `.pt` weights in the `pytorch_weights` directory."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Input: PyTorch checkpoint that has to be placed in `pytorch_weights/` by hand.\n",
    "model_weights_path = \"pytorch_weights/yolov5s_face.pt\"\n",
    "# Output directory for every ONNX file written by this notebook. Keep the trailing\n",
    "# slash: the export path is built from it by plain string concatenation.\n",
    "models_path = \"onnx_models/\"\n",
    "\n",
    "# Writing a file into a missing directory fails, so make sure the output\n",
    "# directory exists before the first model is exported.\n",
    "os.makedirs(models_path, exist_ok=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!mkdir yoloface_repo\n",
    "%cd yoloface_repo\n",
    "!git clone https://github.com/deepcam-cn/yolov5-face.git\n",
    "%cd ..\n",
    "!cp -r yoloface_repo/yolov5-face/models/ models/\n",
    "!cp -r yoloface_repo/yolov5-face/utils/ utils/\n",
    "!rm -rf yoloface_repo"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Libraries\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "from PIL import Image\n",
    "import json\n",
    "import numpy as np\n",
    "import onnx\n",
    "import onnxruntime as ort\n",
    "print(ort.__version__)\n",
    "\n",
    "# Source code\n",
    "from models.common import Conv, ShuffleV2Block\n",
    "from models.experimental import attempt_load\n",
    "from utils.activations import Hardswish, SiLU\n",
    "from utils.general import set_logging, check_img_size"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export to ONNX"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "onnx_opset = 18\n",
    "img_size = [640, 640]\n",
    "batch_size = 1\n",
    "dynamic_shapes = False\n",
    "\n",
    "# Load PyTorch model\n",
    "model = attempt_load(\n",
    "    model_weights_path, map_location=torch.device(\"cpu\")\n",
    ")  # load FP32 model\n",
    "delattr(model.model[-1], \"anchor_grid\")\n",
    "model.model[-1].anchor_grid = [\n",
    "    torch.zeros(1)\n",
    "] * 3  # nl=3 number of detection layers\n",
    "model.model[-1].export_cat = True\n",
    "model.eval()\n",
    "labels = model.names\n",
    "\n",
    "# Checks\n",
    "gs = int(max(model.stride))  # grid size (max stride)\n",
    "img_size = [\n",
    "    check_img_size(x, gs) for x in img_size\n",
    "]  # verify img_size are gs-multiples\n",
    "\n",
    "# Test input\n",
    "img = torch.zeros(batch_size, 3, *img_size)\n",
    "\n",
    "# Update model\n",
    "for k, m in model.named_modules():\n",
    "    m._non_persistent_buffers_set = set()  # pytorch 1.6.0 compatibility\n",
    "    if isinstance(m, Conv):  # assign export-friendly activations\n",
    "        if isinstance(m.act, nn.Hardswish):\n",
    "            m.act = Hardswish()\n",
    "        elif isinstance(m.act, nn.SiLU):\n",
    "            m.act = SiLU()\n",
    "    if isinstance(m, ShuffleV2Block):  # shufflenet block nn.SiLU\n",
    "        for i in range(len(m.branch1)):\n",
    "            if isinstance(m.branch1[i], nn.SiLU):\n",
    "                m.branch1[i] = SiLU()\n",
    "        for i in range(len(m.branch2)):\n",
    "            if isinstance(m.branch2[i], nn.SiLU):\n",
    "                m.branch2[i] = SiLU()\n",
    "y = model(img)  # dry run\n",
    "\n",
    "# ONNX export\n",
    "print(\"\\nStarting ONNX export with onnx %s...\" % onnx.__version__)\n",
    "onnx_model_export_path = models_path + model_weights_path.replace(\".pt\", \".onnx\").split('/')[-1]\n",
    "model.fuse()  \n",
    "input_names = [\"input\"]\n",
    "output_names = [\"output\"]\n",
    "torch.onnx.export(\n",
    "    model,\n",
    "    img,\n",
    "    onnx_model_export_path,\n",
    "    verbose=False,\n",
    "    opset_version=onnx_opset,\n",
    "    input_names=input_names,\n",
    "    output_names=output_names,\n",
    "    dynamic_axes=(\n",
    "        {\"input\": {0: \"batch\"}, \"output\": {0: \"batch\"}} if dynamic_shapes else None\n",
    "    ),\n",
    ")\n",
    "\n",
    "# Checks\n",
    "onnx_model = onnx.load(onnx_model_export_path)  # load onnx model\n",
    "onnx.checker.check_model(onnx_model)  # check onnx model\n",
    "\n",
    "# onnx infer\n",
    "providers = [\"CPUExecutionProvider\"]\n",
    "session = ort.InferenceSession(onnx_model_export_path, providers=providers)\n",
    "im = img.cpu().numpy().astype(np.float32)  # torch to numpy\n",
    "y_onnx = session.run(\n",
    "    [session.get_outputs()[0].name], {session.get_inputs()[0].name: im}\n",
    ")[0]\n",
    "print(\"pred's shape is \", y_onnx.shape)\n",
    "print(\"max(|torch_pred - onnx_pred|） =\", abs(y.cpu().numpy() - y_onnx).max())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm -rf models/\n",
    "!rm -rf utils/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Altering ONNX model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add preprocessing inside model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "from onnxruntime_extensions.tools.pre_post_processing import PrePostProcessor, create_named_value, Resize, ImageBytesToFloat, Unsqueeze, Debug, LetterBox, ChannelsLastToChannelsFirst"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "inputs = [create_named_value(\"input_to_preprocess\", onnx.TensorProto.UINT8, [\"H\", \"W\", \"C\"])]\n",
    "\n",
    "pipeline = PrePostProcessor(inputs, onnx_opset)\n",
    "\n",
    "pipeline.add_pre_processing(\n",
    "    [\n",
    "        Resize(640, layout= \"HWC\", policy=\"not_larger\"), # Resize to 640, maintaining aspect ratio and letting largest dimension not exceed 640 (so smaller dimension will be <= 640)\n",
    "        # Debug(),\n",
    "        LetterBox((640, 640), layout=\"HWC\", fill_value=114),  # Add padding to make the image actually always 640x640,\n",
    "        ChannelsLastToChannelsFirst(),  # Convert to CHW\n",
    "        # Debug(),\n",
    "        ImageBytesToFloat(),  # Convert to float in range 0..1 by dividing uint8 values by 255\n",
    "        # Debug(),\n",
    "        Unsqueeze([0]),  # add batch, CHW --> 1CHW\n",
    "        # Debug(),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# pipeline.add_post_processing()\n",
    "onnx_model_prepro = pipeline.run(onnx_model)\n",
    "onnx.checker.check_model(onnx_model_prepro)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To debug and visually inspect the preprocessing, please uncomment the `Debug()` statements in above block and run it again, and then uncomment and run the code in the block below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# onnx.save(onnx_model_prepro, \"yolov5s_face_prepro.onnx\")\n",
    "\n",
    "# image_singapore = Image.open(\"../data/singapore.jpg\").convert('RGB')\n",
    "# image_singapore_onnx = np.array(image_singapore)\n",
    "# print(image_singapore_onnx.shape)\n",
    "# print(type(image_singapore_onnx))\n",
    "# print(image_singapore_onnx.dtype)\n",
    "\n",
    "# ort_session = ort.InferenceSession(\"yolov5s_face_prepro.onnx\")\n",
    "# test = ort_session.run(None, {\"input_to_preprocess\": image_singapore_onnx})\n",
    "\n",
    "# preprocessed = test[4]\n",
    "# print(preprocessed.shape)\n",
    "# print(type(preprocessed))\n",
    "\n",
    "# # import matplotlib#.pyplot as plt\n",
    "# from IPython.display import display\n",
    "# # matplotlib.use('TkAgg')\n",
    "\n",
    "# displayable_array = preprocessed.reshape(3, 640, 640).transpose((1, 2, 0))\n",
    "# # Display the image\n",
    "# # matplotlib.pyplot.imshow(displayable_array)\n",
    "# # matplotlib.pyplot.axis('off')  \n",
    "# # matplotlib.pyplot.show()\n",
    "# display(Image.fromarray((displayable_array * 255).astype(np.uint8)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add slice operator for use of RGBA input"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a new input with flexible channel dimension\n",
    "new_input = onnx.helper.make_tensor_value_info(\n",
    "    name=\"input\",\n",
    "    elem_type=onnx.TensorProto.UINT8,\n",
    "    shape=[\"H\", \"W\", \"C\"],  \n",
    ")\n",
    "\n",
    "# Create constant tensors for starts, ends, and axes and use them to create a Slice node\n",
    "starts_tensor = onnx.helper.make_tensor(\n",
    "    name=\"starts\",\n",
    "    data_type=onnx.TensorProto.INT64,\n",
    "    dims=[1],\n",
    "    vals=np.array([0], dtype=np.int64)\n",
    ")\n",
    "ends_tensor = onnx.helper.make_tensor(\n",
    "    name=\"ends\",\n",
    "    data_type=onnx.TensorProto.INT64,\n",
    "    dims=[1],\n",
    "    vals=np.array([3], dtype=np.int64)\n",
    ")\n",
    "axes_tensor = onnx.helper.make_tensor(\n",
    "    name=\"axes\",\n",
    "    data_type=onnx.TensorProto.INT64,\n",
    "    dims=[1],\n",
    "    vals=np.array([2], dtype=np.int64)\n",
    ")\n",
    "slice_node = onnx.helper.make_node(\n",
    "    \"Slice\",\n",
    "    inputs=[\"input\", \"starts\", \"ends\", \"axes\"],\n",
    "    outputs=[\"sliced_input\"],\n",
    "    name=\"slice_rgba_input_node\",\n",
    ")\n",
    "# Combine initializers\n",
    "initializers = [starts_tensor, ends_tensor, axes_tensor] + list(onnx_model_prepro.graph.initializer)\n",
    "\n",
    "# Get the name of the original input\n",
    "original_input_name = onnx_model_prepro.graph.input[0].name\n",
    "\n",
    "# Make new graph by adding the new input and Slice node to the old graph\n",
    "graph = onnx.helper.make_graph(\n",
    "    [slice_node] + list(onnx_model_prepro.graph.node),  # Prepend Slice node to existing nodes\n",
    "    onnx_model_prepro.graph.name,\n",
    "    [new_input] + list(onnx_model_prepro.graph.input)[1:],  # Replace first input, keep others\n",
    "    list(onnx_model_prepro.graph.output),\n",
    "    initializer=initializers,\n",
    "    value_info=onnx_model_prepro.graph.value_info,\n",
    ")\n",
    "\n",
    "# Create the new model\n",
    "onnx_model_rgba = onnx.helper.make_model(\n",
    "    graph,\n",
    "    opset_imports=[onnx.helper.make_opsetid(\"\", onnx_opset)]\n",
    ")\n",
    "\n",
    "# Update the input names in the rest of the model\n",
    "for node in onnx_model_rgba.graph.node:\n",
    "    for i, input_name in enumerate(node.input):\n",
    "        if input_name == original_input_name:\n",
    "            node.input[i] = \"sliced_input\"\n",
    "\n",
    "# Save the new model\n",
    "onnx.checker.check_model(onnx_model_rgba)\n",
    "onnx_model_rgba_path = onnx_model_export_path[:-5] + \"_rgba.onnx\"\n",
    "onnx.save(onnx_model_rgba, onnx_model_rgba_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image = Image.open(\"../data/man.jpeg\").convert('RGBA')\n",
    "# image_onnx = np.array(image)\n",
    "# print(image_onnx.shape)\n",
    "# print(type(image_onnx))\n",
    "# print(image_onnx.dtype)\n",
    "\n",
    "# ort_session = ort.InferenceSession(\"yolov5s_face_rgba.onnx\")\n",
    "# test = ort_session.run(None, {\"input\": image_onnx})\n",
    "# print(test[0].shape)\n",
    "# scores_output = test[0][:,:,4]\n",
    "# print(f\"Highest score: {scores_output.max()}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add post-processing inside the model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's first rename the output of the model so we can name the post-processed output as `output`.\n",
    "Then we have to split `[1, 25200, 16]` into `[1, 25200, 4]`, `[1, 25200, 1]`, `[1, 25200, 11]` (i.e. `[1, detections, bbox]`, `[1, detections, score]`, `[1, detections, landmarks]`) named `boxes_unsqueezed`, `scores_unsqueezed`, `masks_unsqueezed`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add a Split operator at the end so that its outputs can be fed to the NonMaxSuppression operator added in the next step\n",
    "num_det = 25200\n",
    "graph = onnx_model_rgba.graph\n",
    "\n",
    "# Let's first rename the output of the model so we can name the post-processed output as `output`\n",
    "for node in onnx_model_rgba.graph.node:\n",
    "    for i, output_name in enumerate(node.output):\n",
    "        if output_name == \"output\":\n",
    "            node.output[i] = \"og_output\"\n",
    "og_output = onnx.helper.make_tensor_value_info(\n",
    "    name=\"og_output\",\n",
    "    elem_type=onnx.TensorProto.FLOAT,\n",
    "    shape=[1, num_det, 16],  \n",
    ")\n",
    "\n",
    "# Create the split node\n",
    "boxes_output = onnx.helper.make_tensor_value_info(\n",
    "    name=\"boxes_unsqueezed\",\n",
    "    elem_type=onnx.TensorProto.FLOAT,\n",
    "    shape=[1, num_det, 4],  \n",
    ")\n",
    "scores_output = onnx.helper.make_tensor_value_info(\n",
    "    name=\"scores_unsqueezed\",\n",
    "    elem_type=onnx.TensorProto.FLOAT,\n",
    "    shape=[1, num_det, 1],  \n",
    ")\n",
    "masks_output = onnx.helper.make_tensor_value_info(\n",
    "    name=\"masks_unsqueezed\",\n",
    "    elem_type=onnx.TensorProto.FLOAT,\n",
    "    shape=[1, num_det, 11],  \n",
    ")\n",
    "splits_tensor = onnx.helper.make_tensor(\n",
    "    name=\"splits\",\n",
    "    data_type=onnx.TensorProto.INT64,\n",
    "    dims=[3],\n",
    "    vals=np.array([4, 1, 11], dtype=np.int64)\n",
    ")\n",
    "split_node = onnx.helper.make_node(\n",
    "        \"Split\",\n",
    "        inputs=[\"og_output\", \"splits\"],\n",
    "        outputs=[\"boxes_unsqueezed\", \"scores_unsqueezed\", \"masks_unsqueezed\"],\n",
    "        name=\"split_og_output\",\n",
    "        axis=2,\n",
    ")\n",
    "\n",
    "# Combine initializers\n",
    "initializers = list(graph.initializer) + [splits_tensor]\n",
    "\n",
    "# Make new graph by adding the new outputs and Split node to the old graph\n",
    "graph = onnx.helper.make_graph(\n",
    "    list(graph.node) + [split_node],  # Append split node to existing nodes\n",
    "    graph.name,\n",
    "    list(graph.input), \n",
    "    [boxes_output, scores_output, masks_output],\n",
    "    initializer=initializers,\n",
    "    value_info=graph.value_info,\n",
    ")\n",
    "\n",
    "# Create the new model\n",
    "onnx_model_split = onnx.helper.make_model(\n",
    "    graph,\n",
    "    opset_imports=[onnx.helper.make_opsetid(\"\", onnx_opset)]\n",
    ")\n",
    "\n",
    "# Save the new model\n",
    "onnx.checker.check_model(onnx_model_split)\n",
    "onnx_model_split_path = onnx_model_export_path[:-5] + \"_split.onnx\"\n",
    "onnx.save(onnx_model_split, onnx_model_split_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we can run NMS on these split outputs using the `NonMaxSuppression` operator"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_det = 25200\n",
    "graph = onnx_model_split.graph\n",
    "nodes = list(graph.node)\n",
    "outputs = list(graph.output)\n",
    "initializers = list(graph.initializer)\n",
    "original_output = graph.output[0]\n",
    "\n",
    "# Create the Transpose node for the scores (since NMS requires the scores to be in the middle dimension for some reason)\n",
    "transpose_node_score = onnx.helper.make_node(\n",
    "        \"Transpose\",\n",
    "        inputs=[\"scores_unsqueezed\"],\n",
    "        outputs=[\"scores_transposed\"],\n",
    "        name=\"transpose_scores\",\n",
    "        perm=[0, 2, 1],\n",
    ")\n",
    "nodes.append(transpose_node_score)\n",
    "\n",
    "# Create the NMS node\n",
    "nms_indices = onnx.helper.make_tensor_value_info(\"nms_indices\", onnx.TensorProto.INT64, shape=[\"detections\", 3])\n",
    "max_output = onnx.helper.make_tensor(\"max_output\",onnx.TensorProto.INT64, [1], np.array([100], dtype=np.int64))\n",
    "iou_threshold = onnx.helper.make_tensor(\"iou_threshold\",onnx.TensorProto.FLOAT, [1], np.array([0.4], dtype=np.float32))\n",
    "score_threshold = onnx.helper.make_tensor(\"score_threshold\",onnx.TensorProto.FLOAT, [1], np.array([0.6], dtype=np.float32))\n",
    "initializers = initializers + [max_output, iou_threshold, score_threshold]\n",
    "nms_node = onnx.helper.make_node(\n",
    "        \"NonMaxSuppression\",\n",
    "        inputs=[\"boxes_unsqueezed\", \"scores_transposed\", \"max_output\", \"iou_threshold\", \"score_threshold\"],\n",
    "        outputs=[\"nms_indices\"],\n",
    "        name=\"perform_nms\",\n",
    "        center_point_box=1,\n",
    ")\n",
    "nodes.append(nms_node)\n",
    "outputs.append(nms_indices)\n",
    "\n",
    "# Make new graph by adding the NMS indices output and the Transpose and NonMaxSuppression nodes to the old graph\n",
    "graph = onnx.helper.make_graph(\n",
    "    nodes,\n",
    "    graph.name,\n",
    "    list(graph.input), \n",
    "    outputs,\n",
    "    initializer=initializers,\n",
    "    value_info=graph.value_info,\n",
    ")\n",
    "\n",
    "# Create the new model\n",
    "onnx_model_nms = onnx.helper.make_model(\n",
    "    graph,\n",
    "    opset_imports=[onnx.helper.make_opsetid(\"\", onnx_opset)]\n",
    ")\n",
    "\n",
    "# Save the new model\n",
    "onnx.checker.check_model(onnx_model_nms)\n",
    "onnx_model_nms_path = onnx_model_export_path[:-5] + \"_nms.onnx\"\n",
    "onnx.save(onnx_model_nms, onnx_model_nms_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image = Image.open(\"../data/man.jpeg\").convert('RGBA')\n",
    "# image_onnx = np.array(image)\n",
    "\n",
    "# ort_session = ort.InferenceSession(\"yolov5s_face_nms.onnx\")\n",
    "# test = ort_session.run(None, {\"input\": image_onnx})\n",
    "# print(test[3].shape)\n",
    "# print(test[3])\n",
    "# print(test[1][0, 24129, 0])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we need to add some Squeeze, Slice and Gather nodes to handle the indices returned by NMS properly. The goal is that the final output is a very simple array of shape `(detections, 16)` containing only the relevant detections."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_det = 25200\n",
    "graph = onnx_model_nms.graph\n",
    "nodes = list(graph.node)\n",
    "outputs = list(graph.output)\n",
    "initializers = list(graph.initializer)\n",
    "original_output = graph.output[0]\n",
    "\n",
    "# Create Slice node to slice the NMS indices from (detections, 3) to (detections, 1) by taking the third column\n",
    "sliced_indices = onnx.helper.make_tensor_value_info(\"sliced_indices\", onnx.TensorProto.INT64, shape=[\"detections\", 1])\n",
    "outputs.append(sliced_indices)\n",
    "starts_slice_tensor = onnx.helper.make_tensor(\n",
    "    name=\"starts_slice_tensor\",\n",
    "    data_type=onnx.TensorProto.INT64,\n",
    "    dims=[1],\n",
    "    vals=np.array([2], dtype=np.int64)\n",
    ")\n",
    "ends_slice_tensor = onnx.helper.make_tensor(\n",
    "    name=\"ends_slice_tensor\",\n",
    "    data_type=onnx.TensorProto.INT64,\n",
    "    dims=[1],\n",
    "    vals=np.array([3], dtype=np.int64)\n",
    ")\n",
    "axes_slice_tensor = onnx.helper.make_tensor(\n",
    "    name=\"axes_slice_tensor\",\n",
    "    data_type=onnx.TensorProto.INT64,\n",
    "    dims=[1],\n",
    "    vals=np.array([1], dtype=np.int64)\n",
    ")\n",
    "initializers = initializers + [starts_slice_tensor, ends_slice_tensor, axes_slice_tensor]\n",
    "slice_node = onnx.helper.make_node(\n",
    "    \"Slice\",\n",
    "    inputs=[\"nms_indices\", \"starts_slice_tensor\", \"ends_slice_tensor\", \"axes_slice_tensor\"],\n",
    "    outputs=[\"sliced_indices\"],\n",
    "    name=\"slice_nms_indices\",\n",
    ")\n",
    "nodes.append(slice_node)\n",
    "\n",
    "# Create Squeeze node to squeeze the sliced indices\n",
    "squeezed_indices = onnx.helper.make_tensor_value_info(\"squeezed_indices\", onnx.TensorProto.INT64, shape=[\"detections\"])\n",
    "outputs.append(squeezed_indices)\n",
    "squeeze_slice_tensor = onnx.helper.make_tensor(\"squeeze_slice_axis\",onnx.TensorProto.INT64, [1], np.array([1], dtype=np.int64))\n",
    "initializers.append(squeeze_slice_tensor)\n",
    "squeeze_slice_node = onnx.helper.make_node(\n",
    "        \"Squeeze\",\n",
    "        inputs=[\"sliced_indices\", \"squeeze_slice_axis\"],\n",
    "        outputs=[\"squeezed_indices\"],\n",
    "        name=\"squeeze_sliced_indices\",\n",
    ")\n",
    "nodes.append(squeeze_slice_node)\n",
    "\n",
    "# Create Squeeze node to squeeze the original output\n",
    "squeezed_output = onnx.helper.make_tensor_value_info(\"squeezed_output\", onnx.TensorProto.FLOAT, shape=[25200, 16])\n",
    "outputs.append(squeezed_output)\n",
    "squeeze_tensor = onnx.helper.make_tensor(\"squeeze_axis\",onnx.TensorProto.INT64, [1], np.array([0], dtype=np.int64))\n",
    "initializers.append(squeeze_tensor)\n",
    "squeeze_node = onnx.helper.make_node(\n",
    "        \"Squeeze\",\n",
    "        inputs=[\"og_output\", \"squeeze_axis\"],\n",
    "        outputs=[\"squeezed_output\"],\n",
    "        name=\"squeeze_output\",\n",
    ")\n",
    "nodes.append(squeeze_node)\n",
    "\n",
    "\n",
    "# Create Gather node to gather the relevant NMS indices from the original output\n",
    "postpro_output = onnx.helper.make_tensor_value_info(\"output\", onnx.TensorProto.FLOAT, shape=[\"detections\", 16])\n",
    "outputs.append(postpro_output)\n",
    "gather_node = onnx.helper.make_node(\n",
    "    \"Gather\",\n",
    "    inputs=[\"squeezed_output\", \"squeezed_indices\"],\n",
    "    outputs=[\"output\"],\n",
    "    name=\"gather_output\",\n",
    ")\n",
    "nodes.append(gather_node)\n",
    "\n",
    "\n",
    "# Make the new graph\n",
    "graph = onnx.helper.make_graph(\n",
    "    nodes,\n",
    "    graph.name,\n",
    "    list(graph.input), \n",
    "    [postpro_output],\n",
    "    initializer=initializers,\n",
    "    value_info=graph.value_info,\n",
    ")\n",
    "\n",
    "# Create the new model\n",
    "onnx_model_prepostpro = onnx.helper.make_model(\n",
    "    graph,\n",
    "    opset_imports=[onnx.helper.make_opsetid(\"\", onnx_opset)]\n",
    ")\n",
    "\n",
    "# Save the new model\n",
    "onnx.checker.check_model(onnx_model_prepostpro)\n",
    "onnx_model_prepostpro_path = onnx_model_export_path[:-5] + \"_prepostpro.onnx\"\n",
    "onnx.save(onnx_model_prepostpro, onnx_model_prepostpro_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image = Image.open(\"../data/people.jpeg\").convert('RGBA')\n",
    "# image_onnx = np.array(image)\n",
    "\n",
    "# ort_session = ort.InferenceSession(\"yolov5s_face_prepostpro.onnx\")\n",
    "# test = ort_session.run(None, {\"input\": image_onnx})\n",
    "# test[0].shape"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Optimize model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "# define path og model and sim model\n",
    "onnx_model_sim_path = onnx_model_export_path[:-5] + f\"_opset{onnx_opset}_rgba_sim.onnx\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Simplify the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!onnxsim {onnx_model_prepostpro_path} {onnx_model_sim_path}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [],
   "source": [
    "# !onnxsim yolov5s_face_prepostpro.onnx yolov5s_face_opset18_rgba_sim.onnx"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optimize the graph"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Offline graph optimization: with `optimized_model_filepath` set, ONNX Runtime\n",
    "# writes the optimized graph to that path while the session is being created.\n",
    "opt_sess_options = ort.SessionOptions()\n",
    "\n",
    "# Only basic optimizations are applied. (A preceding ORT_DISABLE_ALL assignment was\n",
    "# overwritten immediately and never took effect, so it has been removed.)\n",
    "opt_sess_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC\n",
    "\n",
    "onnx_model_opt_path = onnx_model_export_path[:-5] + f\"_opset{onnx_opset}_rgba_opt.onnx\"\n",
    "opt_sess_options.optimized_model_filepath = onnx_model_opt_path\n",
    "\n",
    "# Name the CPU provider explicitly, as the export sanity check does, so the saved\n",
    "# graph does not depend on which execution providers this ORT build happens to ship.\n",
    "opt_session = ort.InferenceSession(\n",
    "    onnx_model_sim_path, opt_sess_options, providers=[\"CPUExecutionProvider\"]\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prevent splits initializer issue\n",
    "\n",
    "For some weird reason the model can give issues on iOS when there's an initializer named \"splits\". \n",
    "So to prevent that we check and rename any such initializer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [],
   "source": [
    "current_model = onnx.load(onnx_model_opt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def find_duplicates(name_list):\n",
    "    \"\"\"Return the names that occur more than once in `name_list`.\n",
    "\n",
    "    Each duplicated name is reported once; the order of the result is unspecified.\n",
    "    \"\"\"\n",
    "    seen = set()\n",
    "    duplicates = set()\n",
    "\n",
    "    for name in name_list:\n",
    "        if name in seen:\n",
    "            duplicates.add(name)\n",
    "        else:\n",
    "            seen.add(name)\n",
    "\n",
    "    return list(duplicates)\n",
    "\n",
    "# Index the initializers of the optimized model by name\n",
    "initializers = current_model.graph.initializer\n",
    "init_names = [init.name for init in initializers]\n",
    "initializer_dict = {init.name: init for init in initializers}\n",
    "\n",
    "# Look the initializer up with .get(): a model without a `splits` initializer must not\n",
    "# abort the notebook with a KeyError (None is printed instead). The lookup also stays\n",
    "# outside the f-string, because reusing double quotes inside an f-string is a\n",
    "# SyntaxError before Python 3.12.\n",
    "splits_initializer = initializer_dict.get(\"splits\")\n",
    "print(f\"splits initializer: \\n {splits_initializer}\")\n",
    "\n",
    "duplicate_names = find_duplicates(init_names)\n",
    "\n",
    "print(\"Duplicate names:\", duplicate_names)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "metadata": {},
   "outputs": [],
   "source": [
    "def rename_initializer(model, old_name, new_name):\n",
    "    \"\"\"Rename the initializer `old_name` to `new_name` and update its references.\n",
    "\n",
    "    Args:\n",
    "        model: model whose `graph.initializer`, `graph.input` and `graph.node`\n",
    "            entries are modified in place.\n",
    "        old_name: current name of the initializer.\n",
    "        new_name: name the initializer should get.\n",
    "\n",
    "    Returns:\n",
    "        True if an initializer called `old_name` was found and renamed, False\n",
    "        otherwise. When nothing is found the model is left untouched.\n",
    "    \"\"\"\n",
    "    target = next(\n",
    "        (init for init in model.graph.initializer if init.name == old_name), None\n",
    "    )\n",
    "    if target is None:\n",
    "        # Without a matching initializer, `old_name` could be the output of a node;\n",
    "        # rewriting only its consumers would disconnect them from their producer.\n",
    "        return False\n",
    "    target.name = new_name\n",
    "\n",
    "    # The initializer may also be listed among the graph inputs\n",
    "    for graph_input in model.graph.input:\n",
    "        if graph_input.name == old_name:\n",
    "            graph_input.name = new_name\n",
    "\n",
    "    # Point every consuming node at the new name\n",
    "    for node in model.graph.node:\n",
    "        for i, input_name in enumerate(node.input):\n",
    "            if input_name == old_name:\n",
    "                node.input[i] = new_name\n",
    "\n",
    "    return True"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [],
   "source": [
    "rename_initializer(current_model, \"splits\", \"splits_initializer_unique\")\n",
    "\n",
    "# Save the modified model under a new file name. The guard keeps this cell idempotent:\n",
    "# re-running it must not append the suffix a second time, nor overwrite the remembered\n",
    "# pre-rename path with the `_nosplits` one (the cleanup cell deletes that path).\n",
    "nosplits_suffix = \"_nosplits.onnx\"\n",
    "if not onnx_model_opt_path.endswith(nosplits_suffix):\n",
    "    onnx_model_opt_with_splits_path = onnx_model_opt_path\n",
    "    onnx_model_opt_path = onnx_model_opt_path[:-5] + nosplits_suffix\n",
    "onnx.save(current_model, onnx_model_opt_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Add metadata to model\n",
    "\n",
    "https://onnx.ai/onnx/intro/python.html#opset-and-metadata"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "metadata": {},
   "outputs": [],
   "source": [
    "new_yolo_face_model = onnx.load(onnx_model_opt_path)\n",
    "new_yolo_face_model.producer_name = \"EnteYOLOv5Face\"\n",
    "new_yolo_face_model.doc_string = \"YOLOv5 Face detector with built-in pre- and post-processing. Accepts both RGB and RGBA raw bytes input (uint8) in HWC format. Outputs the relevant detections in the format (detections, 16) where the first 4 values are the bounding box coordinates, the fifth is the confidence score, and the rest are the landmarks.\"\n",
    "new_yolo_face_model.graph.doc_string = \"\"\n",
    "new_yolo_face_model.graph.name = \"SliceRGB+Resize+LetterBox+ToFloat+Unsqueeze+YOLOv5Face+NMS+Slice+Gather\"\n",
    "onnx.save(new_yolo_face_model, onnx_model_opt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "!rm {onnx_model_export_path}\n",
    "!rm {onnx_model_rgba_path}\n",
    "!rm {onnx_model_split_path}\n",
    "!rm {onnx_model_nms_path}\n",
    "!rm {onnx_model_prepostpro_path}\n",
    "!rm {onnx_model_sim_path}\n",
    "!rm {onnx_model_opt_with_splits_path}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Tune some settings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 26,
   "metadata": {},
   "outputs": [],
   "source": [
    "# from tqdm import tqdm\n",
    "# import time"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "# image = Image.open(\"../data/people.jpeg\").convert('RGBA')\n",
    "# image_onnx = np.array(image)\n",
    "# time_test_size = 500\n",
    "\n",
    "# sess_options1 = ort.SessionOptions()\n",
    "# sess_options1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
    "# # sess_options.enable_profiling = True\n",
    "# # sess_options.log_severity_level = 0 # Verbose\n",
    "# sess_options1.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL\n",
    "# ort_session1 = ort.InferenceSession(onnx_model_opt_path, sess_options1)\n",
    "\n",
    "# begin_time_1 = time.time()\n",
    "# for i in tqdm(range(time_test_size)):\n",
    "#     _ = ort_session1.run(None, {\"input\": image_onnx})\n",
    "# end_time_1 = time.time()\n",
    "# time_1 = end_time_1 - begin_time_1\n",
    "\n",
    "\n",
    "# sess_options2 = ort.SessionOptions()\n",
    "# sess_options2.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
    "# # sess_options.enable_profiling = True\n",
    "# # sess_options.log_severity_level = 0 # Verbose\n",
    "# sess_options2.inter_op_num_threads = 4\n",
    "# # sess_options2.intra_op_num_threads = 4\n",
    "# sess_options2.execution_mode = ort.ExecutionMode.ORT_PARALLEL\n",
    "# ort_session2 = ort.InferenceSession(onnx_model_opt_path, sess_options2, providers=[\"CPUExecutionProvider\"])\n",
    "\n",
    "# begin_time_2 = time.time()\n",
    "# for i in tqdm(range(time_test_size)):\n",
    "#     _ = ort_session2.run(None, {\"input\": image_onnx})\n",
    "# end_time_2 = time.time()\n",
    "# time_2 = end_time_2 - begin_time_2\n",
    "\n",
    "# print(f\"Time for first execution: {time_1}\")\n",
    "# print(f\"Time for second execution: {time_2}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Lessons learned:\n",
    "1. Use sequential execution\n",
    "2. Use extended graph optimizations\n",
    "3. The number of inter-op threads doesn't have a significant impact\n",
    "4. The number of intra-op threads doesn't have a significant impact"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One final test:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Final sanity check: run the optimized ONNX model once on a single test image.\n",
    "image = Image.open(\"../data/man.jpeg\").convert('RGBA')\n",
    "imageWidth, imageHeight = image.size\n",
    "# Size the raw detections are expressed in (assumed to be the model's letterboxed input size).\n",
    "inputWidth, inputHeight = 640, 640\n",
    "print(imageWidth, imageHeight)\n",
    "image_onnx = np.array(image)\n",
    "\n",
    "sess_options1 = ort.SessionOptions()\n",
    "sess_options1.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED\n",
    "# sess_options.enable_profiling = True\n",
    "# sess_options.log_severity_level = 0 # Verbose\n",
    "# sess_options1.execution_mode = ort.ExecutionMode.ORT_SEQUENTIAL\n",
    "# Hand the configured options to the session so the settings above actually take effect.\n",
    "ort_session = ort.InferenceSession(onnx_model_opt_path, sess_options1)\n",
    "# Keep only the first entry of the first model output for the display cells below.\n",
    "raw_detection = ort_session.run(None, {\"input\": image_onnx})[0][0]\n",
    "print(raw_detection.shape)\n",
    "raw_detection"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 29,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image, ImageDraw\n",
    "from IPython.display import display\n",
    "\n",
    "def display_face_detection(image_path, face_box, landmarks):\n",
    "    \"\"\"Show the image at `image_path` with a face box and landmark dots drawn on it.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image file to open.\n",
    "        face_box: Box coordinates handed to `ImageDraw.rectangle`; drawn as a red outline of width 2.\n",
    "        landmarks: Iterable of (x, y) pairs; each one is drawn as a filled blue dot of radius 3.\n",
    "    \"\"\"\n",
    "    canvas = Image.open(image_path)\n",
    "    painter = ImageDraw.Draw(canvas)\n",
    "\n",
    "    # Face outline first, then one filled dot per landmark.\n",
    "    painter.rectangle(face_box, outline=\"red\", width=2)\n",
    "    dot_radius = 3\n",
    "    for center_x, center_y in landmarks:\n",
    "        dot_bounds = [\n",
    "            center_x - dot_radius,\n",
    "            center_y - dot_radius,\n",
    "            center_x + dot_radius,\n",
    "            center_y + dot_radius,\n",
    "        ]\n",
    "        painter.ellipse(dot_bounds, fill=\"blue\")\n",
    "\n",
    "    # Render the annotated image inline in the notebook.\n",
    "    display(canvas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "def correct_detection_and_display(image_path, raw_detection, imageWidth, imageHeight, inputWidth, inputHeight):\n",
    "    \"\"\"Map one raw detection back onto the original image and display it.\n",
    "\n",
    "    The detection is given in pixels of the letterboxed input\n",
    "    (`inputWidth` x `inputHeight`). It is converted to coordinates relative to\n",
    "    that input, corrected for the letterbox scaling and padding, converted to\n",
    "    absolute pixels of the original image, and passed to `display_face_detection`.\n",
    "\n",
    "    Args:\n",
    "        image_path: Path of the image that `display_face_detection` opens and draws on.\n",
    "        raw_detection: Indexable sequence; entries 0-3 are the box as\n",
    "            (center x, center y, width, height) and entries 5-14 are five\n",
    "            interleaved landmark (x, y) pairs.\n",
    "        imageWidth: Width of the original image in pixels.\n",
    "        imageHeight: Height of the original image in pixels.\n",
    "        inputWidth: Width of the letterboxed input in pixels.\n",
    "        inputHeight: Height of the letterboxed input in pixels.\n",
    "    \"\"\"\n",
    "    # Box corners (left, top, right, bottom) and landmarks, relative to the letterboxed input.\n",
    "    centerX, centerY, boxWidth, boxHeight = raw_detection[:4]\n",
    "    box = [\n",
    "        (centerX - boxWidth / 2) / inputWidth,\n",
    "        (centerY - boxHeight / 2) / inputHeight,\n",
    "        (centerX + boxWidth / 2) / inputWidth,\n",
    "        (centerY + boxHeight / 2) / inputHeight,\n",
    "    ]\n",
    "    landmarks = [\n",
    "        (x / inputWidth, y / inputHeight)\n",
    "        for x, y in zip(raw_detection[5:15:2], raw_detection[6:15:2])\n",
    "    ]\n",
    "\n",
    "    # Correct the bounding box and landmarks for letterboxing during preprocessing\n",
    "    scale = min(inputWidth / imageWidth, inputHeight / imageHeight)\n",
    "    scaledWidth = round(imageWidth * scale)\n",
    "    scaledHeight = round(imageHeight * scale)\n",
    "    print(f\"scaledWidth: {scaledWidth}, scaledHeight: {scaledHeight}\")\n",
    "\n",
    "    # Padding on each side; the correction assumes the scaled image is centered in the input.\n",
    "    halveDiffX = (inputWidth - scaledWidth) / 2\n",
    "    halveDiffY = (inputHeight - scaledHeight) / 2\n",
    "    print(f\"halveDiffX: {halveDiffX}, halveDiffY: {halveDiffY}\")\n",
    "    # The horizontal factor must use the input *width*; using the height is only\n",
    "    # correct when the input happens to be square.\n",
    "    scaleX = inputWidth / scaledWidth\n",
    "    scaleY = inputHeight / scaledHeight\n",
    "    translateX = - halveDiffX / inputWidth\n",
    "    translateY = - halveDiffY / inputHeight\n",
    "    print(f\"scaleX: {scaleX}, scaleY: {scaleY}\")\n",
    "    print(f\"translateX: {translateX}, translateY: {translateY}\")\n",
    "\n",
    "    # Remove the padding offset, then stretch so the image content spans 0..1.\n",
    "    box = [\n",
    "        (box[0] + translateX) * scaleX,\n",
    "        (box[1] + translateY) * scaleY,\n",
    "        (box[2] + translateX) * scaleX,\n",
    "        (box[3] + translateY) * scaleY,\n",
    "    ]\n",
    "    landmarks = [((x + translateX) * scaleX, (y + translateY) * scaleY) for x, y in landmarks]\n",
    "\n",
    "    # Convert the bounding box and landmarks to absolute values\n",
    "    box = [box[0] * imageWidth, box[1] * imageHeight, box[2] * imageWidth, box[3] * imageHeight]\n",
    "    landmarks = [(x * imageWidth, y * imageHeight) for x, y in landmarks]\n",
    "\n",
    "    print(\"Bounding box:\", box)\n",
    "    print(\"Landmarks:\", landmarks)\n",
    "\n",
    "    display_face_detection(image_path, box, landmarks)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "image_path = \"../data/man.jpeg\"\n",
    "\n",
    "# Example of hand-written inputs in the format display_face_detection draws:\n",
    "# face_box = (50, 10, 100, 100)  # (left, top, right, bottom)\n",
    "# landmarks = [\n",
    "#     (30, 30),  # Left eye\n",
    "#     (80, 30),  # Right eye\n",
    "#     (55, 50),  # Nose\n",
    "#     (35, 80),  # Left mouth corner\n",
    "#     (75, 80)   # Right mouth corner\n",
    "# ]\n",
    "\n",
    "# Undo the letterboxing on the detection from the inference cell and draw it on the image.\n",
    "correct_detection_and_display(\n",
    "    image_path=image_path,\n",
    "    raw_detection=raw_detection,\n",
    "    imageWidth=imageWidth,\n",
    "    imageHeight=imageHeight,\n",
    "    inputWidth=inputWidth,\n",
    "    inputHeight=inputHeight,\n",
    ")\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "ente_clip",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
